Guide to Implementing a Custom loss_func
When implementing a custom loss_func in ROLL, the two most critical aspects are how the loss is aggregated and how loss_scale is handled. Mishandling either can cause the final loss or gradients to deviate from what a single forward pass over the entire global batch would produce, introducing training bias. The bias is especially severe in scenarios that combine data parallelism (DP), gradient accumulation (GA), and sequence packing.
1. Common Loss Aggregation Strategies
Consider a global batch containing $N$ sequences. Let the length of the $i$-th sequence be $T_i$, with a per-token mask $m_{i,t} \in \{0, 1\}$ indicating whether position $t$ participates in loss computation. The number of valid tokens is:

$$
M = \sum_{i=1}^{N} \sum_{t=1}^{T_i} m_{i,t}
$$

Let $\ell_{i,t}$ denote the per-token loss at position $t$ of sequence $i$ (e.g., NLL, CE, KL divergence, policy loss, etc.).
1.1 Token-level Loss (token-mean)
Compute the average loss over all valid tokens in the global batch:

$$
\mathcal{L}_{\text{token-mean}} = \frac{1}{M} \sum_{i=1}^{N} \sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}
$$
Property: Each token has equal weight; longer sequences contribute more due to having more valid tokens.
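A minimal PyTorch sketch of this reduction. The tensor names per_token_loss and mask are illustrative assumptions, not ROLL's actual loss_func signature:

```python
import torch

def token_mean_loss(per_token_loss: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average loss over all valid tokens in a batch.

    per_token_loss: [batch, seq_len], the per-token losses l_{i,t}
    mask:           [batch, seq_len], 1 where the token counts, 0 otherwise (m_{i,t})
    """
    # Masked sum of losses divided by the number of valid tokens M.
    return (per_token_loss * mask).sum() / mask.sum().clamp(min=1)
```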
1.2 Sequence-level Loss (seq-mean)
First aggregate within each sequence, then average across sequences. ROLL commonly uses two variants:
(a) seq-mean-token-sum
Sum losses over tokens within each sequence, then average across sequences:

$$
\mathcal{L}_{\text{seq-mean-token-sum}} = \frac{1}{N} \sum_{i=1}^{N} \sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}
$$
(b) seq-mean-token-mean
Average losses over tokens within each sequence, then average across sequences:

$$
\mathcal{L}_{\text{seq-mean-token-mean}} = \frac{1}{N} \sum_{i=1}^{N} \frac{1}{\sum_{t=1}^{T_i} m_{i,t}} \sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}
$$
Property: Each sequence has equal weight, avoiding bias due to sequence length differences.
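Under the same illustrative assumptions as above (unpacked [batch, seq_len] tensors), the two variants differ only in the inner, within-sequence reduction:

```python
import torch

def seq_mean_token_sum(per_token_loss: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Sum masked losses within each sequence, then average across sequences.
    per_seq_sum = (per_token_loss * mask).sum(dim=-1)  # [batch]
    return per_seq_sum.mean()

def seq_mean_token_mean(per_token_loss: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Average masked losses within each sequence, then average across sequences.
    per_seq_sum = (per_token_loss * mask).sum(dim=-1)   # [batch]
    per_seq_cnt = mask.sum(dim=-1).clamp(min=1)         # valid tokens per sequence
    return (per_seq_sum / per_seq_cnt).mean()
```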
2. Micro-batch Partitioning in Distributed Training
In practice, a single global training step typically involves:
- Data Parallelism (DP): The global batch is split across multiple DP ranks;
- Gradient Accumulation (GA): Each rank further splits its data into multiple micro-batches, processed sequentially;
- Sequence Packing: To reduce padding and improve GPU utilization, multiple samples are concatenated into fixed-length packed sequences (see the sketch after this list).
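With packing, tokens from several samples share one row, so sequence-level aggregation needs a per-token sample index. A hedged sketch follows; the sample_ids tensor is an assumed way of encoding sample boundaries, not ROLL's actual packing format:

```python
import torch

def per_sample_sums(per_token_loss: torch.Tensor,
                    mask: torch.Tensor,
                    sample_ids: torch.Tensor,
                    num_samples: int):
    """Reduce one packed row into per-sample loss sums and token counts.

    per_token_loss: [total_tokens] losses for the packed row
    mask:           [total_tokens] validity mask
    sample_ids:     [total_tokens] index of the sample each token belongs to
    """
    sums = torch.zeros(num_samples, dtype=per_token_loss.dtype,
                       device=per_token_loss.device)
    counts = torch.zeros_like(sums)
    # index_add_ scatters each token's masked loss into its sample's slot.
    sums.index_add_(0, sample_ids, per_token_loss * mask)
    counts.index_add_(0, sample_ids, mask.to(sums.dtype))
    return sums, counts  # feed these into a seq-mean or token-mean reduction
```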
Let:
- the DP world size be $D$,
- the number of gradient accumulation steps be $G$.

Then the total number of micro-batches per global step is $K = D \times G$.
Denote the set of samples in the $k$-th micro-batch as $\mathcal{B}_k$. Its number of valid tokens is:

$$
M_k = \sum_{i \in \mathcal{B}_k} \sum_{t=1}^{T_i} m_{i,t}
$$

The number of sequences (samples) in this micro-batch is $N_k = |\mathcal{B}_k|$, satisfying:

$$
\sum_{k=1}^{K} N_k = N, \qquad \sum_{k=1}^{K} M_k = M
$$
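These counts matter because normalizing each micro-batch by its own $M_k$ and then averaging does not, in general, equal the global token-mean when the $M_k$ differ. One common remedy, sketched below under the assumption that torch.distributed is already initialized (function and variable names are illustrative): gather the global $M$ once up front, then normalize every micro-batch loss by it.

```python
import torch
import torch.distributed as dist

def global_valid_token_count(local_masks) -> torch.Tensor:
    """All-reduce valid-token counts so every DP rank knows the global M."""
    m = sum(mask.sum() for mask in local_masks)
    dist.all_reduce(m, op=dist.ReduceOp.SUM)
    return m

def micro_batch_loss(per_token_loss, mask, global_m):
    # Normalize by the GLOBAL token count M, not this micro-batch's M_k.
    # Summed over all K micro-batches, these terms reproduce the global
    # token-mean exactly, however the tokens were partitioned.
    return (per_token_loss * mask).sum() / global_m
```

Note that if the backend averages rather than sums gradients across DP ranks (as PyTorch DDP does by default), the per-rank loss must additionally be multiplied by the DP world size $D$; this is precisely the loss_scale correction mentioned at the top of this guide.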